package ioc

import (
	"fmt"
	"gitee.com/xiao_hange/go-admin-pkg/pkg/logger"
	prometheus2 "github.com/prometheus/client_golang/prometheus"
	"github.com/spf13/viper"
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"gorm.io/gorm/utils"
	"gorm.io/plugin/opentelemetry/tracing"
	"gorm.io/plugin/prometheus"
	"time"
)

// SlowThreshold is the statement duration at or above which GormLog tags a
// logged statement with "IsSlow: true".
const SlowThreshold = 10 * time.Millisecond

// InitDB opens the MySQL database whose connection string is read from the
// "mysql.addr" config key. It then installs three gorm plugins:
//   - the Prometheus plugin, with a MySQL status collector for Threads_running;
//   - OpenTelemetry tracing;
//   - the Callbacks plugin, which logs each statement through l and records
//     its latency.
//
// It panics if the connection cannot be opened or any plugin fails to
// register. Misconfiguration therefore fails fast at startup rather than
// surfacing later as a nil *gorm.DB.
//
// NOTE: newCallbacks registers its summary with the default Prometheus
// registry via MustRegister, so InitDB is meant to be called once per process.
func InitDB(l logger.GormLogger) *gorm.DB {
	addr := viper.GetString("mysql.addr")
	db, err := gorm.Open(mysql.Open(addr))

	// The custom gorm logger configuration below is disabled. Per-statement
	// logging is done by the Callbacks plugin registered at the end of this
	// function instead.
	//db, err := gorm.Open(mysql.Open(addr), &gorm.Config{
	//	Logger: glogger.New(gormLoggerfunc(l.Info), glogger.Config{
	//		SlowThreshold:             time.Millisecond * 10,
	//		Colorful:                  false,
	//		IgnoreRecordNotFoundError: true,
	//		ParameterizedQueries:      false,
	//		LogLevel:                  glogger.Info,
	//	}),
	//})

	if err != nil {
		panic(fmt.Errorf("open mysql: %w", err))
	}

	err = db.Use(prometheus.New(prometheus.Config{
		DBName:          "GoAdmin",
		RefreshInterval: 15,
		StartServer:     false,
		MetricsCollector: []prometheus.MetricsCollector{
			&prometheus.MySQL{
				// Must match the MySQL status variable name exactly;
				// the previous "thread_running" matched nothing.
				VariableNames: []string{"Threads_running"},
			},
		},
	}))
	if err != nil {
		// Previously this returned nil silently; fail loudly like Open does.
		panic(fmt.Errorf("register gorm prometheus plugin: %w", err))
	}

	err = db.Use(tracing.NewPlugin(
		tracing.WithDBName("GoAdmin"),
		// Identity formatter: the query text is passed through unchanged.
		tracing.WithQueryFormatter(func(query string) string {
			return query
		}),
		tracing.WithoutMetrics(),        // do not record metrics
		tracing.WithoutQueryVariables(), // do not record query parameters
	))
	if err != nil {
		panic(fmt.Errorf("register gorm tracing plugin: %w", err))
	}

	// The error was previously ignored; gorm reports e.g. duplicate plugin names here.
	err = db.Use(newCallbacks(l))
	if err != nil {
		panic(fmt.Errorf("register gorm callbacks plugin: %w", err))
	}

	return db
}

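// This func-type adapter pattern only works for single-method interfaces. It
// lets a plain function (l.Info) satisfy the Printf-only writer passed to
// glogger.New in the commented-out gorm config in InitDB.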
//type gormLoggerfunc func(msg string, fields ...logger.Field)
//
//func (g gormLoggerfunc) Printf(msg string, args ...interface{}) {
//	g("", logger.Field{Key: "gorm", Val: args})
//}

// Callbacks is a gorm plugin (see Name and Initialize). It brackets gorm's
// create, update, delete, raw, row and query callback chains with timing
// hooks. Each executed statement is logged through l, and its latency is
// observed in vector.
type Callbacks struct {
	vector *prometheus2.SummaryVec // statement latency, labelled by "type" and "table"
	l      logger.GormLogger       // sink for the per-statement log line
}

// newCallbacks builds the timing plugin. It creates a summary vector named
// "Gorm_Query_Time" (namespace "SunXQ", subsystem "GoAdmin"), labelled by
// statement "type" and "table". The summary is registered with the default
// Prometheus registry via MustRegister, which panics if an equivalent
// collector is already registered. The summary is paired with l for
// per-statement logging.
func newCallbacks(l logger.GormLogger) *Callbacks {
	// Quantile -> allowed absolute error of the quantile estimate.
	objectives := map[float64]float64{
		0.5:   0.01,
		0.9:   0.01,
		0.99:  0.005,
		0.999: 0.0001,
	}
	opts := prometheus2.SummaryOpts{
		Namespace:  "SunXQ",
		Subsystem:  "GoAdmin",
		Name:       "Gorm_Query_Time",
		Help:       "统计 GORM 的执行时间",
		Objectives: objectives,
	}
	vec := prometheus2.NewSummaryVec(opts, []string{"type", "table"})

	prometheus2.MustRegister(vec)
	return &Callbacks{
		l:      l,
		vector: vec,
	}
}

// before returns a gorm callback that stores the current time on the
// statement under the "start_time" key, where the matching after callback
// reads it back.
func (pcb *Callbacks) before() func(db *gorm.DB) {
	stamp := func(db *gorm.DB) {
		db.Set("start_time", time.Now())
	}
	return stamp
}

// after returns a gorm callback for statements of kind typ. registerAll uses
// "create", "update", "delete", "raw", "row" and "query".
//
// The callback reads the start time stored by before and measures the elapsed
// time once. It then does two things with that value:
//   - logs the statement (SQL with variables explained by the dialector, rows
//     affected, duration in milliseconds) via GormLog;
//   - observes the same duration, in fractional milliseconds, in pcb.vector,
//     labelled by typ and table. An empty table name is reported as "unknown".
func (pcb *Callbacks) after(typ string) func(db *gorm.DB) {
	return func(db *gorm.DB) {
		val, _ := db.Get("start_time")
		startTime, ok := val.(time.Time)
		if !ok {
			// No start time means before did not run for this statement.
			// Measuring from the zero time.Time would record a nonsensical
			// multi-century duration in both the log and the metric, so skip.
			return
		}
		// Measure once so the log entry and the metric agree. Previously the
		// metric was taken after logging and included the logging cost.
		elapsedMs := float64(time.Since(startTime)) / float64(time.Millisecond)

		table := db.Statement.Table
		if table == "" {
			table = "unknown"
		}
		sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...)

		logEntry := &GormLog{
			Position: utils.FileWithLineNum(),
			Duration: elapsedMs,
			SQL:      sql,
			Rows:     db.Statement.RowsAffected,
			l:        pcb.l,
		}
		logEntry.String()

		// Keep fractional milliseconds: Duration.Milliseconds() truncated
		// every sub-millisecond query to 0.
		pcb.vector.WithLabelValues(typ, table).Observe(elapsedMs)
	}
}

// Name implements gorm.Plugin. It returns the identifier under which db.Use
// registers this plugin.
func (pcb *Callbacks) Name() string {
	const pluginName = "prometheus-query"
	return pluginName
}

// Initialize implements gorm.Plugin; db.Use calls it to install the timing
// callbacks on db.
//
// registerAll panics on any registration failure rather than returning an
// error, so this method always returns nil.
func (pcb *Callbacks) Initialize(db *gorm.DB) error {
	// Reaching the return below means every callback was installed.
	pcb.registerAll(db)
	return nil
}

// registerAll installs a before/after timing pair on each of gorm's callback
// chains: create, update, delete, raw, row and query.
//
// Each before hook is registered with Before("*") and each after hook with
// After("*"), so the pair brackets the rest of the chain. Registration order
// is unchanged from the original. It panics on the first registration error.
func (pcb *Callbacks) registerAll(db *gorm.DB) {
	cb := db.Callback()

	// must panics on a registration error. Every Register result goes through
	// it. Previously the row "after" error was overwritten before being checked.
	must := func(err error) {
		if err != nil {
			panic(err)
		}
	}

	// INSERT statements.
	must(cb.Create().Before("*").Register("prometheus_create_before", pcb.before()))
	must(cb.Create().After("*").Register("prometheus_create_after", pcb.after("create")))

	// UPDATE statements.
	must(cb.Update().Before("*").Register("prometheus_update_before", pcb.before()))
	must(cb.Update().After("*").Register("prometheus_update_after", pcb.after("update")))

	// DELETE statements.
	must(cb.Delete().Before("*").Register("prometheus_delete_before", pcb.before()))
	must(cb.Delete().After("*").Register("prometheus_delete_after", pcb.after("delete")))

	// Raw chain.
	must(cb.Raw().Before("*").Register("prometheus_raw_before", pcb.before()))
	must(cb.Raw().After("*").Register("prometheus_raw_after", pcb.after("raw")))

	// Row chain.
	must(cb.Row().Before("*").Register("prometheus_row_before", pcb.before()))
	must(cb.Row().After("*").Register("prometheus_row_after", pcb.after("row")))

	// Query (SELECT) chain.
	must(cb.Query().Before("*").Register("prometheus_query_before", pcb.before()))
	must(cb.Query().After("*").Register("prometheus_query_after", pcb.after("query")))
}

// GormLog is a single statement record produced by the Callbacks plugin.
// Its String method writes the record to l rather than returning it.
type GormLog struct {
	Position string            // source position reported by utils.FileWithLineNum
	Duration float64           // elapsed time in milliseconds
	SQL      string            // statement with its variables explained by the dialector
	Rows     int64             // rows affected; records with -1 are not logged
	l        logger.GormLogger // destination for the formatted line
}

// String formats this record and writes it through gl.l.Info with the
// message "Gorm". Despite its name it returns nothing and does not satisfy
// fmt.Stringer.
//
// Records whose Rows is -1 (or lower) are skipped. The emitted line contains:
//   - the current local time,
//   - the source position,
//   - the duration in milliseconds,
//   - the SQL,
//   - the row count.
//
// " | IsSlow: true" is appended when Duration is at least SlowThreshold.
func (gl GormLog) String() {
	// Rows == -1 records are deliberately not logged.
	if gl.Rows <= -1 {
		return
	}

	slowMs := float64(SlowThreshold) / float64(time.Millisecond)
	now := time.Now().Format("2006-01-02 15:04:05.00")
	line := fmt.Sprintf("ExecutionTime: %s | Position: %s | Duration: %.4fms | SQL: %s | Rows: %d",
		now, gl.Position, gl.Duration, gl.SQL, gl.Rows)
	if gl.Duration >= slowMs {
		line += " | IsSlow: true"
	}

	gl.l.Info("Gorm", logger.GormString(line))
}
